import json
import os
import os.path as osp
import re
import requests
import time
import traceback
from pydantic import BaseModel
from copy import deepcopy
from typing import Any, Dict, List, Union

from openai import AzureOpenAI

from .base import BaseVideoEvalDataset, filter_metadata

# ---------------------------------------------------------------------------
# Prompt templates shown to the model under evaluation.
# They are filled with ``str.format`` in ``MMVUDataset.generate_instruction``:
# ``{question}`` is always substituted, ``{option_string}`` only exists in the
# multiple-choice variants. "COT" templates ask for step-by-step reasoning,
# "DO" (direct-output) templates ask for the answer only.
# ---------------------------------------------------------------------------

# Multiple-choice, chain-of-thought: reasoning followed by a fixed
# "Therefore, the final answer is: $LETTER" closing sentence.
MULTI_CHOICE_COT_PROMPT = """
Question: {question}
{option_string}

Answer the given multiple-choice question step by step. Begin by explaining your reasoning process clearly. Conclude by stating the final answer using the following format: 'Therefore, the final answer is: $LETTER' (without quotes), where $LETTER is one of the options. Think step by
step before answering."""

# Open-ended, chain-of-thought.
# NOTE(review): the quoted answer format below contains an unbalanced quote
# ("'Therefore, the final answer is: 'Answer: $ANSWER'"). It is left untouched
# because editing the prompt changes what the evaluated model sees — confirm
# against the benchmark's reference prompt before fixing.
OPEN_ENDED_COT_PROMPT = """
Question: {question}

Answer the given question step by step. Begin by explaining your reasoning process clearly. Conclude by stating the final answer using the following format: 'Therefore, the final answer is: 'Answer: $ANSWER' (without quotes), where $ANSWER is the final answer of the question. Think step by step
before answering."""

# Multiple-choice, direct answer: the model should reply with the option letter only.
MULTI_CHOICE_DO_PROMPT = """
Question: {question}
{option_string}

Do not generate any intermediate reasoning process. Answer directly with the option letter from the given choices.
"""

# Open-ended, direct answer: the model should reply with a short final answer only.
OPEN_ENDED_DO_PROMPT = """
Question: {question}

Do not generate any intermediate reasoning process. Directly output the final short answer.
"""

# ---------------------------------------------------------------------------
# System prompts for the judge model. ``MMVUDataset.process_response`` sends
# one of the two specialised variants below as the "system" message.
# ---------------------------------------------------------------------------

# Shared preamble: extract the final answer from the response, then compare it
# with the ground truth.
INSTRUCTION = """Evaluate whether the model's final answer is correct by comparing it to the ground-truth answer provided for the given question.

You should first extract the final answer from the model's response, and then compare the extracted answer with the ground-truth answer to determine its accuracy.
"""

# Judge prompt for multiple-choice questions: the extracted answer is expected
# to be a single option letter.
MULTI_CHOICE_INSTRUCTION = INSTRUCTION + "Output your response in the following structured format:\n" + """{
    "extracted_answer": // str value "A" "B" "C" "D" "E", should be a single character
    "correct": // boolean value, True if the answer is correct, False otherwise
}
"""

# Judge prompt for open-ended questions: wording may differ from the ground
# truth, but the answer must be unambiguously equivalent to count as correct.
OPEN_ENDED_INSTRUCTION = INSTRUCTION + """The final answer generated by the model does not need to match the ground-truth answer word-for-word. However, it should only be considered correct if it demonstrates the exact same technique or concept explicitly and unambiguously equivalent to the ground-truth answer.
Output your response in the following structured format:
{
    "extracted_answer": // str value, the short final answer extracted from the model's response, do not hallucinate one that is not present in the response
    "correct": // boolean value, True if the answer is correct, False otherwise
}
"""

# Structured verdict requested from the judge model; this class is passed as
# ``response_format`` to the chat-completions ``parse`` call in
# ``MMVUDataset.process_response``.
# NOTE: documented with comments instead of a docstring on purpose — pydantic
# copies a model's docstring into the generated JSON schema ("description"),
# which would alter the schema sent to the judge.
class EvaluationOutput(BaseModel):
    extracted_answer: str  # final answer the judge extracted from the model response
    correct: bool  # judge's decision on whether that answer matches the ground truth


class MMVUDataset(BaseVideoEvalDataset):
    """MMVU validation split, graded by an Azure OpenAI judge model.

    Every sample of ``validation.json`` is registered twice: once with a
    direct-answer prompt and once with a chain-of-thought prompt. Model
    responses are graded by a judge deployment that returns an
    ``EvaluationOutput``; metrics are reported separately for both prompt
    styles.

    Environment variables read by this class:
        ENDPOINT_URL: Azure OpenAI endpoint.
        AZURE_OPENAI_API_KEY: API key for that endpoint.
        DEPLOYMENT_NAME: judge deployment name (defaults to ``gpt-4o-0806``).
    """

    BENCHMARK_TYPE: str = "mcqa"

    # Maximum number of judge attempts per sample. ``None`` retries until the
    # call succeeds, which is the historical behaviour of this class; set an
    # integer to make permanent failures (bad credentials, wrong deployment)
    # surface as an exception instead of looping forever.
    JUDGE_MAX_ATTEMPTS: Union[int, None] = None
    # Seconds to sleep between two judge attempts.
    JUDGE_RETRY_DELAY: float = 1.0

    def __init__(self, *args, **kwargs):
        """Initialise the base dataset and the Azure OpenAI judge client.

        All positional and keyword arguments are forwarded unchanged to the
        base class.
        """
        super().__init__(*args, **kwargs)
        # Key-based authentication; endpoint and key come from the environment.
        self.client = AzureOpenAI(
            azure_endpoint=os.getenv("ENDPOINT_URL"),
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version="2024-08-01-preview",
        )

    def load_data(self, data_root: str) -> Dict[int, Any]:
        """Load ``validation.json`` and build the per-sample metadata table.

        Args:
            data_root: Directory containing ``validation.json`` and a
                ``videos`` folder.

        Returns:
            A dict keyed by consecutive integer ids. Each source sample yields
            two entries: id ``2 * i`` with ``use_cot=False`` and id
            ``2 * i + 1`` with ``use_cot=True``.
        """
        video_folder = osp.join(data_root, "videos")
        json_file = osp.join(data_root, "validation.json")
        # Explicit encoding: JSON is UTF-8, the platform default may not be.
        with open(json_file, "r", encoding="utf-8") as f:
            data_list = json.load(f)

        data_dict = {}
        for sample_idx, sample in enumerate(data_list):
            # Only the last directory component and the file name of the
            # recorded video path are kept and re-rooted under ``videos``.
            video_path = osp.join(
                video_folder,
                osp.basename(osp.dirname(sample["video"])),
                osp.basename(sample["video"]),
            )

            record = {
                # required fields for data loading
                "video_path": video_path,
                "start_time": None,
                "end_time": None,
                # required fields for evaluation
                "task_type": sample["metadata"]["subject"],
                "ground_truth": sample["answer"],
                # custom fields for instruction generation and post processing
                "question": sample["question"],
                "question_type": sample["question_type"],
                "options": sample["choices"],
                "original_id": sample["id"],
            }

            data_dict[2 * sample_idx] = {**record, "use_cot": False}
            data_dict[2 * sample_idx + 1] = {**record, "use_cot": True}

        return data_dict

    @staticmethod
    def _format_options(options: Dict[str, Any]) -> str:
        """Render the answer choices as one ``key: value`` line per option."""
        return "\n".join(f"{key}: {value}" for key, value in options.items())

    def generate_instruction(self, data_id: Union[int, str], video: Any) -> str:
        """Build the prompt for one sample.

        Args:
            data_id: Key into ``self.data_dict``.
            video: Unused here; kept for interface compatibility.

        Returns:
            The filled prompt template matching the sample's question type
            and its ``use_cot`` flag.

        Raises:
            ValueError: If the question type is neither ``multiple-choice``
                nor ``open-ended``.
        """
        meta_data = self.data_dict[data_id]
        question_type = meta_data["question_type"]
        use_cot = meta_data["use_cot"]

        if question_type == "multiple-choice":
            template = MULTI_CHOICE_COT_PROMPT if use_cot else MULTI_CHOICE_DO_PROMPT
            return template.format(
                question=meta_data["question"],
                option_string=self._format_options(meta_data["options"]),
            )
        if question_type == "open-ended":
            template = OPEN_ENDED_COT_PROMPT if use_cot else OPEN_ENDED_DO_PROMPT
            return template.format(question=meta_data["question"])

        raise ValueError(f"Unknown question type: {question_type}")

    def _build_judge_messages(self, meta_data: Dict[str, Any], response: str) -> List[Dict[str, str]]:
        """Assemble the system/user messages sent to the judge for one sample."""
        question_type = meta_data["question_type"]

        if question_type == "multiple-choice":
            option_string = self._format_options(meta_data["options"])
            question_context = f"Question: {meta_data['question']}\n\nOptions:\n{option_string}"
            system_prompt = MULTI_CHOICE_INSTRUCTION
        elif question_type == "open-ended":
            question_context = f"Question: {meta_data['question']}"
            system_prompt = OPEN_ENDED_INSTRUCTION
        else:
            raise ValueError(f"Unknown question type: {question_type}")

        gt_answer = f"Ground Truth Answer: {meta_data['ground_truth']}"
        model_response = f"Model Response to the Question: {response}"
        user_prompt = f"{question_context}\n\n{gt_answer}\n\n{model_response}"

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

    def _judge_is_correct(self, messages: List[Dict[str, str]]) -> bool:
        """Ask the judge for a verdict, retrying on failure.

        Retries follow ``JUDGE_MAX_ATTEMPTS`` / ``JUDGE_RETRY_DELAY``. Only
        ``Exception`` subclasses are retried, so ``KeyboardInterrupt`` and
        ``SystemExit`` still stop the run.

        Raises:
            Exception: The last judge error, once ``JUDGE_MAX_ATTEMPTS`` is
                set and exhausted.
        """
        attempt = 0
        while True:
            attempt += 1
            try:
                completion = self.client.beta.chat.completions.parse(
                    model=os.getenv("DEPLOYMENT_NAME", "gpt-4o-0806"),
                    messages=messages,
                    temperature=1.0,
                    max_tokens=128,
                    top_p=1.0,
                    response_format=EvaluationOutput,
                )
                verdict = completion.choices[0].message.parsed
                if verdict is None:
                    # No structured output came back (e.g. a refusal); treat it
                    # like any other failed attempt, but with a clear message.
                    raise ValueError("Judge returned no parsed verdict")
                return verdict.correct
            except Exception:
                traceback.print_exc()
                if self.JUDGE_MAX_ATTEMPTS is not None and attempt >= self.JUDGE_MAX_ATTEMPTS:
                    raise
                print(f"Judge call failed (attempt {attempt}); retrying...")
                time.sleep(self.JUDGE_RETRY_DELAY)

    def process_response(self, data_id: Union[int, str], response: str) -> Any:
        """Grade a model response with the judge and map it to a prediction.

        Args:
            data_id: Key into ``self.data_dict``.
            response: Raw text produced by the evaluated model.

        Returns:
            The sample's ground truth when the judge marks the response as
            correct, otherwise the unmodified ``response``. Presumably the
            base class scores by comparing this value with the ground truth —
            confirm against ``BaseVideoEvalDataset.evaluate``.

        Raises:
            ValueError: If the sample has an unknown question type.
        """
        meta_data = self.data_dict[data_id]
        messages = self._build_judge_messages(meta_data, response)

        if self._judge_is_correct(messages):
            return meta_data["ground_truth"]
        return response

    def evaluate(self, results):
        """Evaluate direct-answer and chain-of-thought results separately.

        Args:
            results: Iterable of result dicts, each carrying a ``data_id``
                that is a key of ``self.data_dict``.

        Returns:
            ``(metrics, infos)``: two dicts with keys ``without_cot`` and
            ``with_cot``, holding what the base class ``evaluate`` returns for
            the respective subset.
        """
        results_wo_cot, results_w_cot = [], []
        for item in results:
            uses_cot = self.data_dict[item["data_id"]]["use_cot"]
            (results_w_cot if uses_cot else results_wo_cot).append(item)

        metrics, infos = {}, {}
        metrics["without_cot"], infos["without_cot"] = super().evaluate(results_wo_cot)
        metrics["with_cot"], infos["with_cot"] = super().evaluate(results_w_cot)
        return metrics, infos
